import math
import pandas as pd
from datetime import datetime
from catboost import CatBoostClassifier

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from hmmlearn import hmm, vhmm

from bots.botlibs.labeling_lib import *
from bots.botlibs.tester_lib import tester_one_direction
from bots.botlibs.export_lib import export_model_to_ONNX

def get_prices() -> pd.DataFrame:
    """Load the close-price series of the configured symbol.

    Reads ``files/<symbol>.csv`` (symbol taken from ``hyper_params['symbol']``)
    as whitespace-separated text, joins the ``<DATE>`` and ``<TIME>`` columns
    into a timestamp and keeps ``<CLOSE>`` as the single ``close`` column.

    Returns:
        DataFrame indexed by ``time`` with one ``close`` column; rows
        containing missing values are dropped.
    """
    # Raw string: '\s' inside a normal literal is an invalid escape sequence
    # and triggers a SyntaxWarning on recent Python versions.
    raw = pd.read_csv('files/' + hyper_params['symbol'] + '.csv', sep=r'\s+')

    prices = pd.DataFrame(columns=['time', 'close'])
    prices['time'] = raw['<DATE>'] + ' ' + raw['<TIME>']
    prices['time'] = pd.to_datetime(prices['time'], format='mixed')
    prices['close'] = raw['<CLOSE>']
    prices.set_index('time', inplace=True)
    # NOTE(review): the index is already datetime after the conversion above,
    # so this call looks redundant; kept to match the original behaviour.
    prices.index = pd.to_datetime(prices.index, unit='s')
    return prices.dropna()

def get_features(data: pd.DataFrame) -> pd.DataFrame:
    """Append rolling standard-deviation columns to a copy of ``data``.

    One column is added per window in ``hyper_params['periods']`` followed by
    one per window in ``hyper_params['periods_meta']``. Columns are named by a
    running counter shared by both lists; the meta ones additionally carry the
    ``meta_feature`` suffix.

    Returns:
        The augmented copy with every row containing a NaN removed.
    """
    features = data.copy()

    # (name suffix, window) pairs in the order the columns are numbered.
    specs = [('', window) for window in hyper_params['periods']]
    specs += [('meta_feature', window) for window in hyper_params['periods_meta']]

    for number, (suffix, window) in enumerate(specs):
        # Always roll over the untouched input, never over the growing frame.
        features[f'{number}{suffix}'] = data.rolling(window).std()

    return features.dropna()
  
def test_model_one_direction(
        result: list,
        stop: float,
        take: float,
        forward: float,
        backward: float,
        markup: float,
        direction: str,
        plt = False):
    """Run the one-direction tester on freshly loaded prices.

    ``result[0]`` is the main model and ``result[1]`` the meta model; both
    must expose ``predict_proba``. The main model sees every feature column
    whose name does not contain ``meta_feature``, the meta model sees the
    remaining ones. Class-1 probabilities are thresholded at 0.5 into the
    ``labels`` / ``meta_labels`` columns before the frame is handed to
    ``tester_one_direction`` together with the remaining arguments.

    Returns:
        Whatever ``tester_one_direction`` returns.
    """
    def to_signal(probability):
        # Values below 0.5 become 0.0, everything else 1.0.
        return 0.0 if probability < 0.5 else 1.0

    frame = get_features(get_prices())

    # The first column is skipped; the rest are the feature columns.
    feature_frame = frame[frame.columns[1:]]
    meta_mask = feature_frame.columns.str.contains('meta_feature')
    main_inputs = feature_frame.loc[:, ~meta_mask]
    meta_inputs = feature_frame.loc[:, meta_mask]

    main_model, meta_model = result[0], result[1]
    frame['labels'] = main_model.predict_proba(main_inputs)[:, 1]
    frame['meta_labels'] = meta_model.predict_proba(meta_inputs)[:, 1]
    for column in ('labels', 'meta_labels'):
        frame[column] = frame[column].apply(to_signal)

    return tester_one_direction(
        frame, stop, take, forward, backward, markup, direction, plt)

def clustering(dataset, n_clusters: int) -> pd.DataFrame:
    """Label rows of the training period with a K-Means cluster id.

    Only rows whose index lies strictly between ``hyper_params['backward']``
    and ``hyper_params['forward']`` are kept. K-Means is fitted on the columns
    whose name contains ``meta_feature`` and its labels are stored in a new
    ``clusters`` column of the returned copy.
    """
    after_start = dataset.index > hyper_params['backward']
    before_end = dataset.index < hyper_params['forward']
    data = dataset[before_end & after_start].copy()

    is_meta = data.columns.str.contains('meta_feature')
    estimator = KMeans(n_clusters=n_clusters)
    estimator.fit(data.loc[:, is_meta])

    data['clusters'] = estimator.labels_
    return data

def sliding_window_clustering(dataset, n_clusters: int, window_size=200) -> pd.DataFrame:
    """Cluster meta features window by window with globally consistent labels.

    Rows whose index lies strictly between ``hyper_params['backward']`` and
    ``hyper_params['forward']`` are kept. A global K-Means fitted on all
    ``meta_feature`` columns provides reference centroids. The data is then
    cut into consecutive, non-overlapping windows of ``window_size`` rows;
    each full window gets its own K-Means, and every local cluster is renamed
    to the index of the nearest global centroid (several local clusters may
    map to the same global one).

    Rows that do not belong to a full window (the trailing remainder, or all
    rows when fewer than ``window_size`` are available) keep the label
    assigned by the global model instead of silently defaulting to 0.

    Args:
        dataset: Frame with a comparable index and ``meta_feature`` columns.
        n_clusters: Number of clusters for the global and the local models.
        window_size: Number of rows per window.

    Returns:
        Copy of the filtered frame with a float ``clusters`` column.
    """
    import numpy as np

    in_period = (dataset.index < hyper_params['forward']) & (dataset.index > hyper_params['backward'])
    data = dataset[in_period].copy()
    meta_X = data.loc[:, data.columns.str.contains('meta_feature')]

    # Global reference centroids.
    global_kmeans = KMeans(n_clusters=n_clusters).fit(meta_X)
    global_centroids = global_kmeans.cluster_centers_

    # Start from the global labels so rows outside every full window are still
    # clustered; float dtype is kept for compatibility with earlier output.
    clusters = global_kmeans.labels_.astype(float)

    for start in range(0, len(data) - window_size + 1, window_size):
        stop = start + window_size
        window_data = meta_X.iloc[start:stop]

        # Local model fitted on the current window only.
        local_kmeans = KMeans(n_clusters=n_clusters).fit(window_data)

        # Pairwise distances, shape (n_local, n_global); each local centroid
        # is mapped to its nearest global centroid.
        offsets = local_kmeans.cluster_centers_[:, None, :] - global_centroids[None, :, :]
        local_to_global = np.linalg.norm(offsets, axis=2).argmin(axis=1)

        # Translate the window's local labels into global ones.
        local_labels = local_kmeans.predict(window_data)
        clusters[start:stop] = local_to_global[local_labels]

    data['clusters'] = clusters
    return data

def markov_regime_switching(dataset, n_regimes: int, model_type="GMMHMM", n_iter = 10) -> pd.DataFrame:
    """Label rows of the training period with hidden-Markov-model regimes.

    Rows whose index lies strictly between ``hyper_params['backward']`` and
    ``hyper_params['forward']`` are kept. The ``meta_feature`` columns are
    standardised, an HMM is fitted on them and the decoded hidden state of
    every row is stored in a ``clusters`` column.

    Args:
        dataset: Frame with a comparable index and ``meta_feature`` columns.
        n_regimes: Number of hidden states.
        model_type: ``"HMM"`` (``GaussianHMM``), ``"GMMHMM"`` (``GMMHMM`` with
            two mixtures per state) or ``"VARHMM"``
            (``VariationalGaussianHMM``).
        n_iter: Iteration limit passed to the model.

    Returns:
        Copy of the filtered frame. When it has no ``meta_feature`` column it
        is returned without a ``clusters`` column.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names
            (previously this surfaced as an ``UnboundLocalError``).
    """
    in_period = (dataset.index < hyper_params['forward']) & (dataset.index > hyper_params['backward'])
    data = dataset[in_period].copy()

    meta_X = data.loc[:, data.columns.str.contains('meta_feature')]
    if meta_X.shape[1] == 0:
        # Nothing to fit on; same outcome as before (no 'clusters' column).
        return data

    # Standardise the features before training.
    X_scaled = StandardScaler().fit_transform(meta_X.values)

    shared = {
        'n_components': n_regimes,
        'covariance_type': "full",
        'n_iter': n_iter,
    }
    if model_type == "HMM":
        model = hmm.GaussianHMM(**shared)
    elif model_type == "GMMHMM":
        model = hmm.GMMHMM(n_mix=2, **shared)
    elif model_type == "VARHMM":
        model = vhmm.VariationalGaussianHMM(**shared)
    else:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; "
            "expected 'HMM', 'GMMHMM' or 'VARHMM'")

    model.fit(X_scaled)
    # The decoded hidden state of each row is its regime / cluster id.
    data['clusters'] = model.predict(X_scaled)
    return data

def markov_regime_switching_prior(dataset, n_regimes: int, model_type="HMM", n_iter=100) -> pd.DataFrame:
    """Fit an HMM initialised from a K-Means clustering of the meta features.

    Rows whose index lies strictly between ``hyper_params['backward']`` and
    ``hyper_params['forward']`` are kept and their ``meta_feature`` columns
    standardised. K-Means with ``n_regimes`` clusters then supplies the
    starting point of the HMM: cluster centres as state means, per-cluster
    empirical covariances, cluster proportions as start probabilities and
    label-to-label transition frequencies as the transition matrix. The model
    is fitted with default initialisation disabled (``init_params=''``) and
    the decoded hidden states are stored in a ``clusters`` column.

    Args:
        dataset: Frame with a comparable index and ``meta_feature`` columns.
        n_regimes: Number of hidden states / K-Means clusters.
        model_type: ``"HMM"``, ``"GMMHMM"`` or ``"VARHMM"``.
        n_iter: Iteration limit passed to the model.

    Returns:
        Copy of the filtered frame. When it has no ``meta_feature`` column it
        is returned without a ``clusters`` column.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names
            (previously this surfaced as an ``UnboundLocalError``).
    """
    # numpy is not imported at module level in this file; importing it here
    # avoids relying on a wildcard import to expose ``np``.
    import numpy as np
    from sklearn.covariance import empirical_covariance

    in_period = (dataset.index < hyper_params['forward']) & (dataset.index > hyper_params['backward'])
    data = dataset[in_period].copy()

    meta_X = data.loc[:, data.columns.str.contains('meta_feature')]
    if meta_X.shape[1] == 0:
        # Nothing to fit on; same outcome as before (no 'clusters' column).
        return data

    # Standardise the features before training.
    X_scaled = StandardScaler().fit_transform(meta_X.values)
    n_features = X_scaled.shape[1]

    # --- K-Means derived starting values ---------------------------------
    kmeans = KMeans(n_clusters=n_regimes, n_init=10)
    cluster_labels = kmeans.fit_predict(X_scaled)
    prior_means = kmeans.cluster_centers_  # (n_regimes, n_features)

    prior_covs = []
    for regime in range(n_regimes):
        members = X_scaled[cluster_labels == regime]
        # A covariance needs at least two points; otherwise fall back to the
        # covariance of the whole sample.
        source = members if len(members) > 1 else X_scaled
        prior_covs.append(empirical_covariance(source))
    prior_covs = np.array(prior_covs)  # (n_regimes, n_features, n_features)

    # Start probabilities from the cluster proportions.
    initial_probs = np.bincount(cluster_labels, minlength=n_regimes) / len(cluster_labels)

    # Transition counts between consecutive cluster labels.
    trans_mat = np.zeros((n_regimes, n_regimes))
    np.add.at(trans_mat, (cluster_labels[:-1], cluster_labels[1:]), 1)
    row_sums = trans_mat.sum(axis=1, keepdims=True)
    no_exit = (row_sums == 0).ravel()
    row_sums[no_exit] = 1  # avoid division by zero
    trans_mat = trans_mat / row_sums
    # A state with no observed outgoing transition would leave an all-zero
    # row, which is not a probability distribution; use a uniform row.
    trans_mat[no_exit] = 1.0 / n_regimes

    # --- Model construction ------------------------------------------------
    shared = {
        'n_components': n_regimes,
        'covariance_type': "full",
        'n_iter': n_iter,
        'init_params': '',  # do not use the default initialisation
    }

    if model_type == "HMM":
        model = hmm.GaussianHMM(**shared)
        model.startprob_ = initial_probs
        model.transmat_ = trans_mat
        model.means_ = prior_means
        model.covars_ = prior_covs

    elif model_type == "GMMHMM":
        n_mix = 2
        model = hmm.GMMHMM(n_mix=n_mix, **shared)
        model.startprob_ = initial_probs
        model.transmat_ = trans_mat
        # With init_params='' the mixture weights are not initialised by the
        # library, so they have to be provided: start from uniform weights.
        model.weights_ = np.full((n_regimes, n_mix), 1.0 / n_mix)

        # means_ has shape (n_components, n_mix, n_features): every mixture
        # starts from the cluster centre shifted by a small multiple of the
        # per-feature standard deviation.
        feature_std = np.std(X_scaled, axis=0)
        means = np.zeros((n_regimes, n_mix, n_features))
        # covars_ has shape (n_components, n_mix, n_features, n_features).
        covars = np.zeros((n_regimes, n_mix, n_features, n_features))
        for regime in range(n_regimes):
            for mix in range(n_mix):
                means[regime, mix] = prior_means[regime] + (mix - n_mix / 2) * 0.2 * feature_std
                # Positive diagonal jitter keeps each matrix symmetric; a
                # full random matrix would make it non-symmetric.
                jitter = np.diag(np.random.rand(n_features) * 0.01)
                covars[regime, mix] = prior_covs[regime] + jitter
        model.means_ = means
        model.covars_ = covars

    elif model_type == "VARHMM":
        model = vhmm.VariationalGaussianHMM(
            # Priors are passed through the constructor.
            means_prior=prior_means,
            beta_prior=np.ones(n_regimes),                   # (n_components,)
            dof_prior=np.ones(n_regimes) * (n_features + 2),  # (n_components,)
            scale_prior=prior_covs,  # (n_components, n_features, n_features)
            **shared,
        )
        # NOTE(review): unverified whether the variational model uses these
        # attributes when init_params='' - confirm against hmmlearn.
        model.startprob_ = initial_probs
        model.transmat_ = trans_mat

    else:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; "
            "expected 'HMM', 'GMMHMM' or 'VARHMM'")

    model.fit(X_scaled)
    # The decoded hidden state of each row is its regime / cluster id.
    data['clusters'] = model.predict(X_scaled)
    return data

def fit_final_models(clustered, meta) -> list:
    """Train the main and the meta classifier and score the pair.

    For both frames every column except the last is a feature candidate. The
    main model uses those whose name does not contain ``meta_feature`` and is
    trained on ``clustered['labels']``; the meta model uses the
    ``meta_feature`` ones and is trained on ``meta['clusters']``. Each model
    gets its own shuffled 80/20 split, whose 20 % part serves as the CatBoost
    evaluation set for early stopping and best-iteration selection.

    The trained pair is then evaluated with ``test_model_one_direction``; a
    NaN score is replaced by ``-1.0``.

    Returns:
        ``[score, model, meta_model]``.
    """
    def build_classifier(metric):
        # Both classifiers share every setting except the tracked metric.
        return CatBoostClassifier(iterations=500,
                                  custom_loss=[metric],
                                  eval_metric=metric,
                                  verbose=False,
                                  use_best_model=True,
                                  task_type='CPU',
                                  thread_count=-1)

    # Feature matrices (last column of each frame excluded).
    main_inputs = clustered[clustered.columns[:-1]]
    main_inputs = main_inputs.loc[:, ~main_inputs.columns.str.contains('meta_feature')]
    meta_inputs = meta[meta.columns[:-1]]
    meta_inputs = meta_inputs.loc[:, meta_inputs.columns.str.contains('meta_feature')]

    # Targets.
    main_target = clustered['labels'].astype('int16')
    meta_target = meta['clusters'].astype('int16')

    # Train / validation splits, main model first.
    split_options = {'train_size': 0.8, 'test_size': 0.2, 'shuffle': True}
    train_X, test_X, train_y, test_y = train_test_split(
        main_inputs, main_target, **split_options)
    train_X_m, test_X_m, train_y_m, test_y_m = train_test_split(
        meta_inputs, meta_target, **split_options)

    model = build_classifier('Accuracy')
    model.fit(train_X, train_y, eval_set=(test_X, test_y),
              early_stopping_rounds=25, plot=False)

    meta_model = build_classifier('F1')
    meta_model.fit(train_X_m, train_y_m, eval_set=(test_X_m, test_y_m),
                   early_stopping_rounds=25, plot=False)

    score = test_model_one_direction(
        [model, meta_model],
        hyper_params['stop_loss'],
        hyper_params['take_profit'],
        hyper_params['full forward'],
        hyper_params['backward'],
        hyper_params['markup'],
        hyper_params['direction'],
        plt=False,
    )
    if math.isnan(score):
        score = -1.0
        print('R2 is fixed to -1.0')
    print('R2: ' + str(score))

    return [score, model, meta_model]


# Run configuration read by the functions and the script body of this file.
hyper_params = {
    # Base name of the price file: files/<symbol>.csv
    'symbol': 'XAUUSD_H1',
    # Passed to export_model_to_ONNX together with model_number.
    'export_path': '/Users/dmitrievsky/Library/Containers/com.isaacmarovitz.Whisky/Bottles/54CFA88F-36A3-47F7-915A-D09B24E89192/drive_c/Program Files/MetaTrader 5/MQL5/Include/Trend following/',
    'model_number': 0,
    # Passed to the labelling and testing helpers.
    'markup': 0.2,
    'stop_loss': 20.0,
    'take_profit': 5.0,
    # Rolling-std windows of the main features: 5, 35, ..., 275.
    'periods': list(range(5, 300, 30)),
    # Rolling-std windows of the meta features.
    'periods_meta': [5],
    # Rows strictly between 'backward' and 'forward' are used for clustering.
    'backward': datetime(2000, 1, 1),
    'forward': datetime(2024, 1, 1),
    # Forward bound handed to the tester when scoring freshly trained models.
    'full forward': datetime(2026, 1, 1),
    'direction': 'buy',
    'n_clusters': 5,
}


# LEARNING LOOP
# For every regime with enough rows, train a main model on that regime's rows
# and a meta model that separates the regime from all other rows. Each result
# is stored as [score, model, meta_model].
dataset = get_features(get_prices())
models = []

for i in range(1):
    # Alternative regime detectors:
    # data = clustering(dataset, n_clusters=hyper_params['n_clusters'])
    # data = sliding_window_clustering(dataset, n_clusters=hyper_params['n_clusters'], window_size=20)
    data = markov_regime_switching_prior(
        dataset,
        n_regimes=hyper_params['n_clusters'],
        model_type="HMM",
        n_iter=100,
    )

    sorted_clusters = sorted(data['clusters'].unique())
    for clust in sorted_clusters:
        clustered_data = data[data['clusters'] == clust].copy()
        n_samples = len(clustered_data)
        if n_samples < 500:
            # Regimes with fewer than 500 rows are skipped.
            print('too few samples: {}'.format(n_samples))
            continue

        clustered_data = get_labels_one_direction(
            clustered_data,
            markup=hyper_params['markup'],
            min=1,
            max=15,
            direction=hyper_params['direction'],
        )

        print(f'Iteration: {i}, Cluster: {clust}')
        clustered_data = clustered_data.drop(['close', 'clusters'], axis=1)

        # Meta target: 1 for rows of the current regime, 0 for every other row.
        meta_data = data.copy()
        meta_data['clusters'] = meta_data['clusters'].eq(clust).astype('int64')
        meta_data = meta_data.drop(['close'], axis=1)

        models.append(fit_final_models(clustered_data, meta_data))


# TESTING & EXPORT
# Sort ascending by score so the best [score, model, meta_model] entry is
# last. Indexing raises IndexError when no model was trained.
models.sort(key=lambda entry: entry[0])
best_entry = models[-1]

test_model_one_direction(
    best_entry[1:],
    hyper_params['stop_loss'],
    hyper_params['take_profit'],
    hyper_params['forward'],
    hyper_params['backward'],
    hyper_params['markup'],
    hyper_params['direction'],
    plt=True,
)

export_model_to_ONNX(
    model=best_entry,
    symbol=hyper_params['symbol'],
    periods=hyper_params['periods'],
    periods_meta=hyper_params['periods_meta'],
    model_number=hyper_params['model_number'],
    export_path=hyper_params['export_path'],
)


# Bare expression: the value is computed but not stored or printed.
best_entry[2].get_best_score()['validation']